#!/bin/bash
# PyTorch training launch script

# Strict mode: stop on any failing command (-e), on expansion of an unset
# variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Configuration parameters. Each one can be overridden from the environment;
# an unset or empty value falls back to the default given here.
NUM_GPUS=${NUM_GPUS:-8}
# The default is a relative path, so it is resolved against the directory the
# script is invoked from.
CONFIG_FILE=${CONFIG_FILE:-"../../configs/pytorch.json"}
MASTER_PORT=${MASTER_PORT:-29500}

# Startup banner: print the effective settings, one per line, framed by
# separator rules.
printf '%s\n' \
    "==========================================" \
    "PyTorch Baseline Training" \
    "==========================================" \
    "Number of GPUs: $NUM_GPUS" \
    "Config file: $CONFIG_FILE" \
    "Master port: $MASTER_PORT" \
    "=========================================="

# Check that the config file exists; abort with status 1 before any training
# command is started if it does not.
if [[ ! -f "$CONFIG_FILE" ]]; then
    # Diagnostics go to stderr so they do not mix with captured stdout.
    echo "Error: Config file $CONFIG_FILE not found!" >&2
    exit 1
fi

# Succeeds only when $1 is a plain decimal integer equal to 1.
# Any other value (including non-numeric ones) returns non-zero quietly,
# instead of making the `[ ... -eq ... ]` builtin print
# "integer expression expected".
is_single_gpu() {
    local count=${1-}
    # 10# forces base 10 so a value such as "01" is not parsed as octal.
    [[ "$count" =~ ^[0-9]+$ ]] && (( 10#$count == 1 ))
}

if is_single_gpu "$NUM_GPUS"; then
    # Single-GPU training: run the training script directly with python.
    echo "Running single-GPU training..."
    python train_pytorch.py --config "$CONFIG_FILE"
else
    # Multi-GPU DDP training: hand the script to torchrun, passing NUM_GPUS as
    # the per-node process count and MASTER_PORT as the master port.
    echo "Running multi-GPU DDP training with $NUM_GPUS GPUs..."
    torchrun \
        --nproc_per_node="$NUM_GPUS" \
        --master_port="$MASTER_PORT" \
        train_pytorch.py \
        --config "$CONFIG_FILE"
fi

# Closing banner, printed once the training command above has returned.
printf '%s\n' \
    "==========================================" \
    "Training completed!" \
    "=========================================="

